import json
import time
from collections.abc import Generator, Mapping
from typing import Any, Optional, cast

import pydantic
from dify_plugin.entities.agent import AgentInvokeMessage
from dify_plugin.entities.model.llm import LLMModelConfig, LLMUsage
from dify_plugin.entities.model.message import (
    AssistantPromptMessage,
    PromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
    PromptMessageTool,
)
from dify_plugin.entities.tool import (
    LogMetadata,
    ToolInvokeMessage,
    ToolParameter,
    ToolProviderType,
)
from dify_plugin.interfaces.agent import (
    AgentModelConfig,
    AgentScratchpadUnit,
    AgentStrategy,
    ToolEntity,
)
from pydantic import BaseModel

from output_parser.cot_output_parser import CotAgentOutputParser
from prompt.template import REACT_PROMPT_TEMPLATES
from utils.mcp_client import McpClients

# Model providers that reject "Observation" as a stop word; for these the
# stop word is not appended (see _invoke).
ignore_observation_providers = ["wenxin"]


class ReActParams(BaseModel):
    """Validated input parameters for one ReAct strategy invocation."""

    query: str  # the user's question for this run
    instruction: str  # extra instruction injected into the system prompt
    model: AgentModelConfig
    tools: list[ToolEntity] | None
    mcp_servers_config: str | None  # JSON string describing MCP servers (may be quote-wrapped)
    mcp_resources_as_tools: bool = False  # expose MCP resources as callable tools
    mcp_prompts_as_tools: bool = False  # expose MCP prompts as callable tools
    maximum_iterations: int = 3  # hard cap on ReAct rounds


class AgentPromptEntity(BaseModel):
    """
    Agent Prompt Entity: the template used for the first iteration and the
    template used for every subsequent iteration of the agent loop.
    """

    first_prompt: str  # prompt template for the first model call
    next_iteration: str  # prompt template for follow-up iterations


class ReActAgentStrategy(AgentStrategy):
    def __init__(self, runtime, session):
        """Initialize per-conversation state for the two-phase (plan, then execute) ReAct strategy."""
        super().__init__(runtime, session)
        self.query = ""
        self.instruction = ""
        self.history_prompt_messages = []
        self.prompt_messages_tools = []
        # Whether we are still in the first (planning) round of the conversation.
        self._is_first_round = True
        # Names of the tools the planning phase decided to use.
        self._planned_tools = set()
        # Names of the tools actually invoked so far in this conversation.
        self._tools_used_in_conversation = set()
        # Full definitions of all tools available at planning time.
        self._all_available_tools = []

    @property
    def _user_prompt_message(self) -> UserPromptMessage:
        """Wrap the current query as a user prompt message."""
        content = self.query
        return UserPromptMessage(content=content)

    @property
    def _system_prompt_message(self) -> SystemPromptMessage:
        """Build the system prompt for the current phase.

        First access (planning phase): advertises every tool with its full
        schema plus the user query, then flips the phase flag so later
        accesses use the execution template.
        Later accesses (execution phase): only the tools selected during
        planning are exposed, preferring the cached full definition.

        NOTE: reading this property mutates state (`_is_first_round`,
        `_all_available_tools`) — it must be read exactly once per round.

        :raises ValueError: if the selected prompt template is empty.
        """
        templates = REACT_PROMPT_TEMPLATES["english"]["chat"]

        # BUGFIX: snapshot the phase flag ONCE. The previous code flipped
        # `_is_first_round` mid-method and then re-read it below, so the
        # first-round `{{query}}` placeholder was always replaced with "".
        is_first_round = self._is_first_round

        # Planning rounds use the planning template; later rounds the execution one.
        prompt_template = templates["prompt"] if is_first_round else templates["execution"]
        if not prompt_template:
            raise ValueError("Agent prompt configuration is not set")

        tools_info = []
        if is_first_round:
            # Planning round: expose every tool in full so the model can draft
            # a plan, and cache the full definitions for later rounds.
            tools_info = [tool.model_dump(mode="json") for tool in self._prompt_messages_tools]
            self._all_available_tools = list(self._prompt_messages_tools)
            self._is_first_round = False  # subsequent prompts use the execution template
        else:
            # Execution rounds: only expose the tools the plan selected.
            for tool in self._prompt_messages_tools:
                if tool.name in self._planned_tools:
                    cached = next(
                        (t for t in self._all_available_tools if t.name == tool.name),
                        None,
                    )
                    # Prefer the cached full definition; fall back to the live tool.
                    tools_info.append((cached or tool).model_dump(mode="json"))

        system_prompt = (
            prompt_template.replace("{{instruction}}", self.instruction)
            # Only the planning round embeds the query into the system prompt.
            .replace("{{query}}", self.query if is_first_round else "")
            .replace("{{tools}}", json.dumps(tools_info, ensure_ascii=False))
            .replace(
                "{{tool_names}}",
                ", ".join([tool.name for tool in self._prompt_messages_tools]),
            )
        )

        return SystemPromptMessage(content=system_prompt)

    def _invoke(self, parameters: dict[str, Any]) -> Generator[AgentInvokeMessage]:
        """
        Run the ReAct agent loop.

        Validates `parameters`, prepares regular and MCP tools, then repeats
        think -> act -> observe rounds until the model produces a final
        answer or ``maximum_iterations`` is reached. Yields log messages for
        each round / model call / tool call, then a final text message and a
        JSON execution-metadata message.

        :param parameters: raw strategy parameters, validated as ReActParams
        :raises ValueError: on invalid parameters or malformed MCP config
        """

        try:
            react_params = ReActParams(**parameters)
        except pydantic.ValidationError as e:
            raise ValueError(f"Invalid parameters: {e!s}") from e

        # Init parameters
        self.query = react_params.query
        self.instruction = react_params.instruction or self.instruction
        agent_scratchpad = []
        iteration_step = 1
        max_iteration_steps = react_params.maximum_iterations
        run_agent_state = True
        llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
        final_answer = ""
        prompt_messages = []

        # Init model. "Observation" is added as a stop word so the model
        # pauses after emitting an action; some providers reject it.
        # NOTE(review): appending to the list returned by .get("stop") also
        # mutates completion_params when a "stop" key exists — presumably
        # intentional, confirm.
        model = react_params.model
        stop = (
            react_params.model.completion_params.get("stop", [])
            if react_params.model.completion_params
            else []
        )
        if (
                "Observation" not in stop
                and model.provider not in ignore_observation_providers
        ):
            stop.append("Observation")

        # Init prompts
        self.history_prompt_messages = model.history_prompt_messages

        # Index regular tools by name for dispatch during tool calls.
        tools = react_params.tools
        tool_instances = {tool.identity.name: tool for tool in tools} if tools else {}

        # Fetch MCP tools, if an MCP server configuration was provided.
        mcp_clients = None
        mcp_tools = []
        mcp_tool_instances = {}
        mcp_servers_config = react_params.mcp_servers_config
        mcp_resources_as_tools = react_params.mcp_resources_as_tools
        mcp_prompts_as_tools = react_params.mcp_prompts_as_tools
        if mcp_servers_config:
            try:
                # The injected mcp_servers_config variable begins and ends
                # with double quotes, hence the strip before parsing.
                servers_config = json.loads(mcp_servers_config.strip('"'))
            except json.JSONDecodeError as e:
                raise ValueError(f"mcp_servers_config must be a valid JSON string: {e}")
            mcp_clients = McpClients(servers_config, mcp_resources_as_tools, mcp_prompts_as_tools)
            mcp_tools = mcp_clients.fetch_tools()
            mcp_tool_instances = {tool.get("name"): tool for tool in mcp_tools} if mcp_tools else {}

        react_params.model.completion_params = (
                react_params.model.completion_params or {}
        )
        # Convert tools (regular + MCP) into the ModelRuntime tool format.
        prompt_messages_tools = self._init_prompt_tools(tools)
        prompt_messages_tools.extend(self._init_prompt_mcp_tools(mcp_tools))
        self._prompt_messages_tools = prompt_messages_tools

        while run_agent_state and iteration_step <= max_iteration_steps:
            # Keep looping while the model keeps calling tools; a round with
            # no tool call (or hitting the cap) terminates the loop.
            run_agent_state = False
            round_started_at = time.perf_counter()
            round_log = self.create_log_message(
                label=f"ROUND {iteration_step}",
                data={},
                metadata={
                    LogMetadata.STARTED_AT: round_started_at,
                },
                status=ToolInvokeMessage.LogMessage.LogStatus.START,
            )
            yield round_log
            if iteration_step == max_iteration_steps:
                # Last iteration: remove all tools to force a final answer.
                self._prompt_messages_tools = []

            message_file_ids: list[str] = []

            # Rebuild the prompt for this round and recalc llm max tokens.
            prompt_messages = self._organize_prompt_messages(
                agent_scratchpad, self.query
            )
            if model.entity and model.completion_params:
                self.recalc_llm_max_tokens(
                    model.entity, prompt_messages, model.completion_params
                )
            # Invoke the model with streaming output.
            chunks = self.session.model.llm.invoke(
                model_config=LLMModelConfig(**model.model_dump(mode="json")),
                prompt_messages=prompt_messages,
                stream=True,
                stop=stop,
            )

            # usage_dict is filled in-place by the output parser.
            usage_dict = {}
            react_chunks = CotAgentOutputParser.handle_react_stream_output(
                chunks, usage_dict
            )
            scratchpad = AgentScratchpadUnit(
                agent_response="",
                thought="",
                action_str="",
                observation="",
                action=None,
            )

            model_started_at = time.perf_counter()
            model_log = self.create_log_message(
                label=f"{model.model} Thought",
                data={},
                metadata={
                    LogMetadata.STARTED_AT: model_started_at,
                    LogMetadata.PROVIDER: model.provider,
                },
                parent=round_log,
                status=ToolInvokeMessage.LogMessage.LogStatus.START,
            )
            yield model_log

            # Fold the parsed stream into the scratchpad unit: Action chunks
            # set the pending action, plain text accumulates as the thought.
            for chunk in react_chunks:
                if isinstance(chunk, AgentScratchpadUnit.Action):
                    action = chunk
                    # detect action
                    assert scratchpad.agent_response is not None
                    scratchpad.agent_response += json.dumps(chunk.model_dump())

                    scratchpad.action_str = json.dumps(chunk.model_dump())
                    scratchpad.action = action
                else:
                    scratchpad.agent_response = scratchpad.agent_response or ""
                    scratchpad.thought = scratchpad.thought or ""
                    scratchpad.agent_response += chunk
                    scratchpad.thought += chunk
            scratchpad.thought = (
                scratchpad.thought.strip()
                if scratchpad.thought
                else "I am thinking about how to help you"
            )
            agent_scratchpad.append(scratchpad)

            # Accumulate LLM usage reported by the parser (if any).
            if "usage" in usage_dict:
                if usage_dict["usage"] is not None:
                    self.increase_usage(llm_usage, usage_dict["usage"])
            else:
                usage_dict["usage"] = LLMUsage.empty_usage()

            action = (
                scratchpad.action.to_dict()
                if scratchpad.action
                else {"action": scratchpad.agent_response}
            )

            yield self.finish_log_message(
                log=model_log,
                data={"thought": scratchpad.thought, **action},
                metadata={
                    LogMetadata.STARTED_AT: model_started_at,
                    LogMetadata.FINISHED_AT: time.perf_counter(),
                    LogMetadata.ELAPSED_TIME: time.perf_counter() - model_started_at,
                    LogMetadata.PROVIDER: model.provider,
                    LogMetadata.TOTAL_PRICE: usage_dict["usage"].total_price
                    if usage_dict["usage"]
                    else 0,
                    LogMetadata.CURRENCY: usage_dict["usage"].currency
                    if usage_dict["usage"]
                    else "",
                    LogMetadata.TOTAL_TOKENS: usage_dict["usage"].total_tokens
                    if usage_dict["usage"]
                    else 0,
                },
            )
            if not scratchpad.action:
                # No action was parsed: the thought itself is the answer.
                final_answer = scratchpad.thought
            else:
                if scratchpad.action.action_name.lower() == "final answer":
                    # Action is the final answer: return it directly.
                    try:
                        if isinstance(scratchpad.action.action_input, dict):
                            final_answer = json.dumps(scratchpad.action.action_input)
                        elif isinstance(scratchpad.action.action_input, str):
                            final_answer = scratchpad.action.action_input
                        else:
                            final_answer = f"{scratchpad.action.action_input}"
                    except json.JSONDecodeError:
                        final_answer = f"{scratchpad.action.action_input}"
                else:
                    # Action is a tool call: invoke it and keep looping.
                    run_agent_state = True
                    tool_call_started_at = time.perf_counter()
                    tool_name = scratchpad.action.action_name
                    tool_call_log = self.create_log_message(
                        label=f"CALL {tool_name}",
                        data={},
                        metadata={
                            LogMetadata.STARTED_AT: time.perf_counter(),
                            LogMetadata.PROVIDER: tool_instances[
                                tool_name
                            ].identity.provider
                            if tool_instances.get(tool_name)
                            else "",
                        },
                        parent=round_log,
                        status=ToolInvokeMessage.LogMessage.LogStatus.START,
                    )
                    yield tool_call_log
                    tool_invoke_response, tool_invoke_parameters = (
                        self._handle_invoke_action(
                            action=scratchpad.action,
                            tool_instances=tool_instances,
                            mcp_clients=mcp_clients,
                            mcp_tool_instances=mcp_tool_instances,
                            message_file_ids=message_file_ids,
                        )
                    )
                    scratchpad.observation = tool_invoke_response
                    scratchpad.agent_response = tool_invoke_response
                    yield self.finish_log_message(
                        log=tool_call_log,
                        data={
                            "tool_name": tool_name,
                            "tool_call_args": tool_invoke_parameters,
                            "output": tool_invoke_response,
                        },
                        metadata={
                            LogMetadata.STARTED_AT: tool_call_started_at,
                            LogMetadata.PROVIDER: tool_instances[
                                tool_name
                            ].identity.provider
                            if tool_instances.get(tool_name)
                            else "",
                            LogMetadata.FINISHED_AT: time.perf_counter(),
                            LogMetadata.ELAPSED_TIME: time.perf_counter()
                                                      - tool_call_started_at,
                        },
                    )

                # Refresh prompt tool definitions (runtime parameters may change).
                for prompt_tool in self._prompt_messages_tools:
                    if prompt_tool.name in tool_instances:
                        self.update_prompt_message_tool(
                            tool_instances[prompt_tool.name], prompt_tool
                        )
            yield self.finish_log_message(
                log=round_log,
                data={
                    "action_name": scratchpad.action.action_name
                    if scratchpad.action
                    else "",
                    "action_input": scratchpad.action.action_input
                    if scratchpad.action
                    else "",
                    "thought": scratchpad.thought,
                    "observation": scratchpad.observation,
                },
                metadata={
                    LogMetadata.STARTED_AT: round_started_at,
                    LogMetadata.FINISHED_AT: time.perf_counter(),
                    LogMetadata.ELAPSED_TIME: time.perf_counter() - round_started_at,
                    LogMetadata.TOTAL_PRICE: usage_dict["usage"].total_price
                    if usage_dict["usage"]
                    else 0,
                    LogMetadata.CURRENCY: usage_dict["usage"].currency
                    if usage_dict["usage"]
                    else "",
                    LogMetadata.TOTAL_TOKENS: usage_dict["usage"].total_tokens
                    if usage_dict["usage"]
                    else 0,
                },
            )
            iteration_step += 1

        # Close all MCP clients before finishing.
        if mcp_clients:
            mcp_clients.close()

        yield self.create_text_message(final_answer)
        yield self.create_json_message(
            {
                "execution_metadata": {
                    LogMetadata.TOTAL_PRICE: llm_usage["usage"].total_price
                    if llm_usage["usage"] is not None
                    else 0,
                    LogMetadata.CURRENCY: llm_usage["usage"].currency
                    if llm_usage["usage"] is not None
                    else "",
                    LogMetadata.TOTAL_TOKENS: llm_usage["usage"].total_tokens
                    if llm_usage["usage"] is not None
                    else 0,
                }
            }
        )

    def _organize_user_query(
            self, query, prompt_messages: list[PromptMessage]
    ) -> list[PromptMessage]:
        """Append *query* as a user message and return the (mutated) list."""
        user_message = UserPromptMessage(content=query)
        prompt_messages.append(user_message)
        return prompt_messages

    def _organize_prompt_messages(
        self, agent_scratchpad: list, query: str
    ) -> list[PromptMessage]:
        """
        Assemble the prompt for one LLM round while keeping token usage low.

        Layout: system prompt, a short tail of conversation history, the
        current query, then (if any scratchpad exists) one assistant message
        summarizing the most recent scratchpad units.
        """
        # NOTE: reading this property also advances the planning/execution
        # phase on its first access.
        system_message = self._system_prompt_message

        # Caps on scratchpad units and historic messages kept, to bound tokens.
        max_scratchpad_units = 2
        max_history_messages = 2

        messages: list[PromptMessage] = [system_message]
        if self.history_prompt_messages:
            messages.extend(self.history_prompt_messages[-max_history_messages:])
        messages.append(UserPromptMessage(content=query))

        if agent_scratchpad:
            summary = self._format_assistant_message(
                agent_scratchpad[-max_scratchpad_units:]
            )
            messages.append(AssistantPromptMessage(content=summary))

        return messages

    def _handle_invoke_action(
            self,
            action: AgentScratchpadUnit.Action,
            mcp_clients: McpClients | None,
            tool_instances: Mapping[str, ToolEntity],
            mcp_tool_instances: Mapping[str, dict],
            message_file_ids: list[str],
    ) -> tuple[str, dict[str, Any] | str]:
        """
        Execute one tool call decided by the model.

        :param action: parsed action (tool name + raw arguments)
        :param mcp_clients: MCP clients, used when the tool is an MCP tool
        :param tool_instances: regular Dify tool instances, keyed by name
        :param mcp_tool_instances: MCP tool descriptors, keyed by name
        :param message_file_ids: message file ids (currently unused here)
        :return: (observation text, parameters actually sent to the tool)
        :raises ValueError: when string arguments cannot be mapped to tool parameters
        """
        tool_call_name = action.action_name
        tool_call_args = action.action_input
        tool_instance = tool_instances.get(tool_call_name)
        mcp_tool_instance = mcp_tool_instances.get(tool_call_name)

        # The pseudo-tool "Plan" is handled locally: it records which tools
        # the model intends to use and never reaches a real tool backend.
        if tool_call_name == "Plan":
            try:
                # Normalize the planning payload into a dict.
                if not isinstance(tool_call_args, dict):
                    if isinstance(tool_call_args, str):
                        try:
                            tool_call_args = json.loads(tool_call_args)
                        except json.JSONDecodeError:
                            return "规划结果格式错误：必须是有效的JSON对象", tool_call_args
                    else:
                        return "规划结果格式错误：必须是字典或JSON字符串", tool_call_args

                # Record the planned tool set for later rounds.
                self._handle_planning_result(tool_call_args)

                # Echo a human-readable plan summary back to the model.
                steps_summary = "\n".join([
                    f"步骤 {i+1}: {step['step']}\n使用工具: {', '.join(step['tools'])}"
                    for i, step in enumerate(tool_call_args.get('steps', []))
                ])
                return f"规划已完成,执行计划如下:\n\n{steps_summary}", tool_call_args
            except Exception as e:
                # Best effort: report the formatting problem to the model so
                # it can re-plan instead of aborting the whole run.
                print(f"规划处理错误: {e}")
                return (f"规划格式错误：{str(e)}\n请按照正确的格式重新规划:\n" +
                       "{\n  \"steps\": [\n    {\"step\": \"步骤描述\", \"tools\": [\"工具1\", \"工具2\"]},\n    ...\n  ]\n}",
                       tool_call_args)

        # Track which tools are actually invoked during the execution phase.
        if not self._is_first_round:
            self._tools_used_in_conversation.add(tool_call_name)

        if not tool_instance and not mcp_tool_instance:
            answer = f"there is not a tool named {tool_call_name}"
            return answer, tool_call_args

        if isinstance(tool_call_args, str):
            try:
                tool_call_args = json.loads(tool_call_args)
            except json.JSONDecodeError as e:
                # BUGFIX: `tool_instance` is None for MCP-only tools, and the
                # previous code dereferenced `tool_instance.parameters` here,
                # raising AttributeError. Fail with a clear error instead.
                if tool_instance is None:
                    raise ValueError("tool call args is not a valid json string") from e
                # If the tool has exactly one LLM-form parameter, treat the
                # raw string as that parameter's value.
                params = [
                    param.name
                    for param in tool_instance.parameters
                    if param.form == ToolParameter.ToolParameterForm.LLM
                ]
                if len(params) > 1:
                    raise ValueError("tool call args is not a valid json string") from e
                tool_call_args = {params[0]: tool_call_args} if len(params) == 1 else {}

        tool_invoke_parameters = {}
        try:
            if mcp_tool_instance:
                # Invoke via MCP and flatten the content items to text.
                # NOTE(review): assumes `content` is a list of typed items —
                # confirm against McpClients.execute_tool.
                tool_invoke_parameters = tool_call_args
                content = mcp_clients.execute_tool(
                    tool_name=tool_call_name,
                    tool_args=tool_invoke_parameters,
                )
                if len(content) == 1:
                    item = content[0]
                    if item["type"] == "text":
                        result = item["text"]
                    elif item["type"] in ("image", "video"):
                        result = json.dumps(item, ensure_ascii=False)
                    elif item["type"] == "resource":
                        result = json.dumps(item['resource'], ensure_ascii=False)
                    else:
                        result = json.dumps(item, ensure_ascii=False)
                else:
                    result = json.dumps(content, ensure_ascii=False)
            else:
                # Invoke a regular Dify tool; model-supplied arguments win
                # over runtime defaults on key conflicts.
                tool_invoke_parameters = {**tool_instance.runtime_parameters, **tool_call_args}
                tool_invoke_responses = self.session.tool.invoke(
                    provider_type=ToolProviderType(tool_instance.provider_type),
                    provider=tool_instance.identity.provider,
                    tool_name=tool_instance.identity.name,
                    parameters=tool_invoke_parameters,
                )
                result = ""
                for response in tool_invoke_responses:
                    if response.type == ToolInvokeMessage.MessageType.TEXT:
                        result += cast(ToolInvokeMessage.TextMessage, response.message).text
                    elif response.type == ToolInvokeMessage.MessageType.LINK:
                        result += (
                                f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}."
                                + " please tell user to check it."
                        )
                    elif response.type in {
                        ToolInvokeMessage.MessageType.IMAGE_LINK,
                        ToolInvokeMessage.MessageType.IMAGE,
                    }:
                        result += (
                                "image has been created and sent to user already, "
                                + "you do not need to create it, just tell the user to check it now."
                        )
                    elif response.type == ToolInvokeMessage.MessageType.JSON:
                        text = json.dumps(
                            cast(
                                ToolInvokeMessage.JsonMessage, response.message
                            ).json_object,
                            ensure_ascii=False,
                        )
                        result += f"tool response: {text}."
                    else:
                        result += f"tool response: {response.message!r}."
        except Exception as e:
            # Surface tool failures to the model as an observation rather
            # than crashing the agent loop.
            result = f"tool invoke error: {e!s}"

        return result, tool_invoke_parameters

    def _handle_planning_result(self, action_input: dict) -> None:
        """
        Digest the plan produced in round one.

        Validates the step list, collects the union of tool names the plan
        references, and narrows the cached tool catalogue to that set. On
        any validation failure, every known tool remains allowed so the
        agent keeps working.

        :param action_input: planning payload shaped like
            {"steps": [{"step": "...", "tools": ["..."]}]}
        """
        planned_tools = set()
        try:
            if not isinstance(action_input, dict):
                raise ValueError("规划结果必须是一个字典")

            steps = action_input.get('steps')
            if not isinstance(steps, list) or not steps:
                raise ValueError("规划结果必须包含非空的steps列表")

            # Names we can validate planned tools against.
            known_names = {t.name for t in self._all_available_tools}

            for i, step in enumerate(steps, 1):
                if not isinstance(step, dict):
                    raise ValueError(f"步骤 {i} 必须是一个字典")
                if 'step' not in step or 'tools' not in step:
                    raise ValueError(f"步骤 {i} 必须包含 'step' 和 'tools' 字段")
                if not isinstance(step['tools'], list):
                    raise ValueError(f"步骤 {i} 的tools必须是一个列表")

                invalid_tools = [name for name in step['tools'] if name not in known_names]
                if invalid_tools:
                    raise ValueError(f"步骤 {i} 包含无效的工具: {invalid_tools}")

                planned_tools.update(step['tools'])

            # Keep full definitions only for tools the plan actually uses.
            self._all_available_tools = [
                tool for tool in self._all_available_tools
                if tool.name in planned_tools
            ]
            self._planned_tools = planned_tools

            print("\n=== Round 1 规划结果 ===")
            print("执行步骤:")
            for i, step in enumerate(steps, 1):
                print(f"\n步骤 {i}: {step['step']}")
                print(f"所需工具: {step['tools']}")
            print(f"\n已规划使用的工具: {self._planned_tools}")
            print("===========================\n")
        except Exception as e:
            print(f"规划结果解析错误: {e}")
            # On parse failure keep every tool available so the agent still works.
            self._planned_tools = {tool.name for tool in self._all_available_tools}

    def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
        """Build an Action from a parsed {"action": ..., "action_input": ...} dict."""
        name = action["action"]
        args = action["action_input"]
        return AgentScratchpadUnit.Action(action_name=name, action_input=args)

    def _format_assistant_message(
        self, agent_scratchpad: list[AgentScratchpadUnit]
    ) -> str:
        """
        Render scratchpad units as a compact Thought/Action/Observation
        transcript. Observations are truncated to 200 characters to keep
        token usage down.
        """
        parts: list[str] = []
        for unit in agent_scratchpad:
            if unit.thought:
                parts.append(f"Thought: {unit.thought}\n")
            if unit.action:
                # Only the action name and input are shown.
                parts.append(
                    f"Action: {{\n  \"action\": \"{unit.action.action_name}\", "
                    f"\"action_input\": {json.dumps(unit.action.action_input, ensure_ascii=False)}\n}}\n"
                )
            if unit.observation:
                obs = unit.observation
                if len(obs) > 200:
                    obs = obs[:200] + "..."
                parts.append(f"Observation: {obs}\n")
        return "".join(parts).strip()

    @staticmethod
    def _init_prompt_mcp_tools(mcp_tools: list[dict]) -> list[PromptMessageTool]:
        """Convert MCP tool descriptors into PromptMessageTool objects."""
        return [
            PromptMessageTool(
                name=tool.get("name"),
                description=tool.get("description", ""),
                parameters=tool.get("inputSchema"),
            )
            for tool in mcp_tools
        ]
